import numpy as np
import paddle
from paddle_quantum.circuit import UAnsatz


def _add_trainable_block(cir, n, depth, block_theta):
    """Append one trainable block W(theta) to ``cir`` in place.

    Each of the ``depth`` layers applies Rz(block_theta[j][k][0]) followed by
    Ry(block_theta[j][k][1]) to every qubit k, then a ring of CNOTs on the
    qubit pairs [k, (k + 1) % n].

    Args:
        cir: circuit to extend (a paddle_quantum ``UAnsatz``).
        n: number of qubits.
        depth: number of rotation + entangler layers in the block.
        block_theta: parameters for this block, indexed [layer][qubit][0|1].
    """
    for j in range(depth):
        for k in range(n):
            cir.rz(block_theta[j][k][0], k)
            cir.ry(block_theta[j][k][1], k)
        # With a single qubit the ring would be cnot([0, 0]) (same qubit as
        # control and target), which is not a valid gate; there is nothing
        # to entangle, so skip the entangler layer.
        if n > 1:
            for k in range(n):
                cir.cnot([k, (k + 1) % n])


def U_theta(n, block, depth, theta, x):
    """Build the data re-uploading circuit W S W S ... W for one sample.

    The circuit alternates ``block`` pairs of (trainable block W(theta),
    data-encoding layer S(x)) and ends with one extra trainable block, so
    ``theta`` must provide ``block + 1`` blocks.

    Args:
        n: number of qubits; must equal ``x.shape[0]``.
        block: number of W(theta) S(x) repetitions.
        depth: number of layers in each trainable block.
        theta: trainable angles indexed [block][layer][qubit][0|1]; expected
            shape (block + 1, depth, n, 2).
        x: a single sample with ``n`` features. It is passed to
            ``cir.angle_encoding(x, 'rz')``, presumably one Rz angle per
            qubit (confirm against the paddle_quantum docs).

    Returns:
        The constructed ``UAnsatz`` circuit (not yet executed).
    """
    assert n == x.shape[0], "U_theta: expected x.shape[0] == n"
    cir = UAnsatz(n)
    for i in range(block):
        _add_trainable_block(cir, n, depth, theta[i])  # W(theta)
        cir.angle_encoding(x, 'rz')                    # S(x)

    # Final (block + 1)-th trainable block W(theta) after the last encoding.
    _add_trainable_block(cir, n, depth, theta[-1])
    return cir


class QNN(paddle.nn.Layer):
    """Quantum classifier built from the data re-uploading circuit ``U_theta``.

    Holds the trainable angles ``theta``. For each input sample, ``forward``
    builds a fresh circuit, runs it as a state vector, and maps the
    expectation of sum_k Z_k to a prediction.
    """

    def __init__(self,
                 n,            # number of qubits
                 depth,        # depth of each trainable block
                 block,        # number of W(theta) S(x) repetitions (data re-uploads)
                 ):
        super().__init__()
        self.num_qubits = n
        self.depth = depth
        self.block = block

        # Trainable angles indexed [block][layer][qubit][0=rz, 1=ry].
        # There are block + 1 trainable blocks because U_theta appends one
        # extra W(theta) after the last encoding layer. Initialised from a
        # uniform distribution between 0 and 2*pi.
        self.theta = self.create_parameter(
            shape=[block + 1, depth, n, 2],
            default_initializer=paddle.nn.initializer.Uniform(0.0, 2 * np.pi),
            dtype='float64',
            is_bias=False)

    def forward(self, x):
        """Predict one scalar per sample in ``x``.

        Args:
            x: batch of samples; row ``x[i]`` is passed to ``U_theta`` and
                must have ``num_qubits`` entries. Presumably shape
                (batch, num_qubits); TODO confirm.

        Returns:
            Tuple ``(predict, cir)``. ``predict`` is a 1-D tensor with one
            value per sample. ``cir`` is the circuit built for the last
            sample, which callers may print.

        Raises:
            ValueError: if ``x`` contains no samples. Without this check,
                ``paddle.concat([])`` would fail and ``cir`` would be unbound.
        """
        if x.shape[0] == 0:
            raise ValueError("QNN.forward: input batch must be non-empty")

        # Observable H = sum_k 1.0 * Z_k, in paddle_quantum's
        # [coefficient, pauli-string] list format.
        H_info = [[1.0, 'z%s' % i] for i in range(self.num_qubits)]

        predict = []
        for i in range(x.shape[0]):
            cir = U_theta(self.num_qubits, self.block, self.depth, self.theta, x[i])
            cir.run_state_vector()
            # Affine map 0.5 * <H> + 0.5, intended to land in (0, 1).
            # NOTE(review): each <Z_k> is in [-1, 1], so <H> is in [-n, n] and
            # this lands in [0, 1] only when n == 1. Confirm whether a 1/n
            # coefficient in H_info was intended.
            predict.append(cir.expecval(H_info, shots=0) * 0.5 + 0.5)

        return paddle.concat(predict).reshape((-1,)), cir


def train_model(train_X, train_y, test_X, test_y, seed, N, DEPTH, BLOCK, EPOCH=10, BATCH_SIZE=40, LR=0.1):
    """Train a ``QNN`` with mini-batch Adam on mean squared error.

    Test accuracy is evaluated after every epoch. A test prediction counts as
    correct when it is within 0.5 of its label.

    Args:
        train_X, train_y: training samples and labels (paddle tensors).
            Labels are presumably 0/1 with the same dtype as the predictions
            (float64); TODO confirm.
        test_X, test_y: held-out samples and labels, in the same format.
        seed: value passed to ``paddle.seed`` for reproducible initialisation.
        N, DEPTH, BLOCK: qubit count, trainable-block depth and number of
            re-uploading repetitions, forwarded to ``QNN``.
        EPOCH: number of passes over the training data.
        BATCH_SIZE: samples per mini-batch. Trailing samples that do not
            fill a whole batch are skipped each epoch.
        LR: Adam learning rate.

    Returns:
        Tuple ``(train_loss, test_accuracy)``. ``train_loss`` holds one
        ``avg_loss.numpy()`` array per batch. ``test_accuracy`` holds one
        float per epoch.
    """
    paddle.seed(seed)
    net = QNN(N, DEPTH, BLOCK)

    opt = paddle.optimizer.Adam(learning_rate=LR, parameters=net.parameters())

    train_loss, test_accuracy = [], []
    num_batches = train_X.shape[0] // BATCH_SIZE

    for epoch in range(EPOCH):
        for j in range(num_batches):
            # Batch j covers samples [j*BATCH_SIZE, (j+1)*BATCH_SIZE). The
            # original code started at j, which produced overlapping,
            # ever-growing batches.
            start, stop = j * BATCH_SIZE, (j + 1) * BATCH_SIZE
            batch_X = train_X[start:stop]
            batch_y = train_y[start:stop]
            predict, cir = net(batch_X)

            if epoch == 0 and j == 0:
                print(cir)

            avg_loss = paddle.mean((predict - batch_y) ** 2)
            train_loss.append(avg_loss.numpy())

            # .item() works whether the loss is a shape-[1] or a 0-D tensor.
            print("Epoch:%s ----- batch:%s ----- training loss %s" % (epoch, j, avg_loss.numpy().item()))

            # In dygraph mode, minimize() consumes the gradients that
            # backward() accumulated on the parameters.
            avg_loss.backward()
            opt.minimize(avg_loss)
            opt.clear_grad()

        # Compute test accuracy: fraction of predictions within 0.5 of the label.
        predict, _ = net(test_X)
        is_correct = (paddle.abs(predict - test_y) < 0.5).nonzero().shape[0]
        acc = is_correct / test_y.shape[0]
        print("---------------------------------------- test accuracy: %s" % acc)
        test_accuracy.append(acc)

    return train_loss, test_accuracy